package org.example.utils

import kafka.api.{OffsetRequest, PartitionOffsetRequestInfo}
import kafka.common.TopicAndPartition
import kafka.consumer.SimpleConsumer
import kafka.utils.ZkUtils
import org.I0Itec.zkclient.ZkClient
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.SecurityProtocol
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.kafka010.OffsetRange
import org.example.constant.ApolloConst
import org.slf4j.LoggerFactory

import java.sql.{Connection, DriverManager, PreparedStatement}

/**
 * Offset management utility.
 *
 * Persists Kafka consumer offsets per (group, topic, partition) in MySQL and,
 * on restart, reconciles the stored offsets with the earliest offsets the
 * brokers still retain, so a resumed job never requests an out-of-range offset.
 */
object OffsetManager extends Serializable {
  import scala.util.control.NonFatal

  private val logger = LoggerFactory.getLogger(OffsetManager.getClass)

  // Shared, lazily-created ZooKeeper client; closed by the shutdown hook below.
  lazy val zkClient: ZkClient = ZkUtils.createZkClient(ApolloConst.zkKafka, 30000, 30000)
  sys.addShutdownHook {
    logger.info("Execute hook thread: ZkManager")
    zkClient.close()
  }
  lazy val zkUtils: ZkUtils = ZkUtils.apply(zkClient, isZkSecurityEnabled = false)

  /**
   * Opens a new MySQL connection. The caller is responsible for closing it.
   */
  def getConn: Connection = {
    Class.forName("com.mysql.jdbc.Driver")
    DriverManager.getConnection(ApolloConst.jgdMysqlURL, ApolloConst.jgdMysqlUserName, ApolloConst.jgdMysqlPassWord)
  }

  /**
   * Loads the stored offsets for the given consumer group and topics from MySQL.
   *
   * Each stored offset is clamped to be no smaller than the earliest offset
   * Kafka still retains for that partition: if log retention has already
   * deleted data past the stored offset, we fall back to the earliest
   * available offset instead of requesting an out-of-range one.
   *
   * Topics with no stored row produce no entries (Spark will then use its
   * default starting position for those partitions).
   *
   * @param groupid consumer group id
   * @param topic   topics to look up
   * @return map of partition -> offset to start consuming from
   */
  def apply(groupid: String, topic: Array[String]): collection.mutable.Map[TopicPartition, Long] = {
    val offsetRange = collection.mutable.Map[TopicPartition, Long]()
    val conn = getConn
    try {
      val statement = conn.prepareStatement("select * from `offset_db`.kafka_offset where groupid = ? and topic = ?")
      try {
        for (elem <- topic) {
          statement.setString(1, groupid)
          statement.setString(2, elem)
          val rs = statement.executeQuery()
          try {
            while (rs.next()) {
              val tp = new TopicPartition(rs.getString("topic"), rs.getInt("partition"))
              val storedOffset = rs.getLong("untiloffset")
              // Earliest offset the broker still retains for this partition.
              val kafkaOffset = getOffsetForKafka(tp)
              // Resume from the stored offset unless retention already removed it.
              offsetRange += tp -> math.max(storedOffset, kafkaOffset)
            }
          } finally {
            rs.close()
          }
        }
      } finally {
        statement.close()
      }
    } finally {
      // Always release JDBC resources, even when a query or the broker lookup fails.
      conn.close()
    }
    offsetRange
  }

  /**
   * Saves the offsets of the current batch to MySQL in a single transaction
   * (REPLACE INTO keyed on group/topic/partition).
   *
   * On failure the transaction is rolled back and the streaming context is
   * stopped, so the batch will be reprocessed from the last committed offsets
   * on restart. The exception is logged, not rethrown.
   *
   * @param groupid     consumer group id
   * @param offsetRange offset ranges of the batch just processed
   * @param ssc         streaming context to stop if persistence fails
   */
  def saveCurrentBatchOffset(groupid: String, offsetRange: Array[OffsetRange], ssc: StreamingContext): Unit = {
    var conn: Connection = null
    var statement: PreparedStatement = null
    try {
      conn = getConn
      // All rows of the batch commit atomically, or none do.
      conn.setAutoCommit(false)
      statement = conn.prepareStatement("replace into `offset_db`.kafka_offset values(?,?,?,?)")
      for (range <- offsetRange) {
        statement.setString(1, groupid)
        statement.setString(2, range.topic)
        statement.setInt(3, range.partition)
        statement.setLong(4, range.untilOffset)
        statement.executeUpdate()
      }
      conn.commit()
    } catch {
      case NonFatal(e) =>
        // Log with the cause so the failure is diagnosable, then roll back and
        // stop the job so offsets and processed data cannot diverge.
        logger.error("即将回滚事务终止任务！", e)
        // conn may still be null if getConn itself threw; a failed rollback
        // must not mask the original error.
        if (conn != null) {
          try {
            conn.rollback()
          } catch {
            case NonFatal(rollbackError) => logger.error("rollback failed", rollbackError)
          }
        }
        ssc.stop()
    } finally {
      if (statement != null) {
        statement.close()
      }
      if (conn != null) {
        conn.close()
      }
    }
  }

  /**
   * Fetches an offset for a partition directly from its leader broker.
   *
   * @param topicPartition partition to query
   * @param time           offset request time; defaults to the earliest
   *                       retained offset ([[OffsetRequest.EarliestTime]])
   * @return the requested offset
   * @throws NoSuchElementException if the partition has no leader in ZK or
   *                                the broker returned no offsets
   */
  def getOffsetForKafka(topicPartition: TopicPartition, time: Long = OffsetRequest.EarliestTime): Long = {
    val brokerId = zkUtils.getLeaderForPartition(topicPartition.topic, topicPartition.partition).get
    val broker = zkUtils.getBrokerInfo(brokerId).get
    val endpoint = broker.getBrokerEndPoint(SecurityProtocol.PLAINTEXT)
    val consumer = new SimpleConsumer(endpoint.host, endpoint.port, 10000, 100000, "getOffset")
    try {
      val tp = TopicAndPartition(topicPartition.topic, topicPartition.partition)
      val request = OffsetRequest(Map(tp -> PartitionOffsetRequestInfo(time, 1)))
      consumer.getOffsetsBefore(request).partitionErrorAndOffsets(tp).offsets.head
    } finally {
      // SimpleConsumer holds a live socket to the broker; it was previously leaked.
      consumer.close()
    }
  }
}
